{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "collapsed": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loaded rnn.py\n",
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n",
      "(49, 5) (49, 5)\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import numpy as np\n",
    "import theano\n",
    "import theano.tensor as T\n",
    "from scipy import linalg\n",
    " \n",
    "# number of input units\n",
    "n_in = 5\n",
    "# number of hidden units\n",
    "n_hid = 10\n",
    "# number of output units\n",
    "n_out = 5\n",
    " \n",
    "# Generate sinewaves offset in phase\n",
    "n_timesteps = 50\n",
    "d1 = 3 * np.arange(n_timesteps) / (2 * np.pi)\n",
    "d2 = 3 * np.arange(n_in) / (2 * np.pi)\n",
    "all_sines = np.sin(np.array([d1] * n_in).T + d2)\n",
    " \n",
    "# Setup dataset and initial hidden vector of zeros\n",
    "X = all_sines[:-1]\n",
    "y = all_sines[1:]\n",
    "print X.shape, y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import rnn\n",
    "model = rnn.Rnn(n_in,n_hid,n_out, 0.001, single_output=False)\n",
    "\n",
    "# RNN uses a mean squared error and a linear output activation. Feel free to change it in rnn.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "collapsed": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0: err 123.894156\n",
      "Epoch 40: err 0.326533\n",
      "Epoch 80: err 0.166228\n",
      "Epoch 120: err 0.124947\n",
      "Epoch 160: err 0.096020\n",
      "Epoch 200: err 0.075262\n",
      "Epoch 240: err 0.060122\n",
      "Epoch 280: err 0.048925\n",
      "Epoch 320: err 0.040530\n",
      "Epoch 360: err 0.034145\n",
      "Epoch 400: err 0.029218\n",
      "Epoch 440: err 0.025358\n",
      "Epoch 480: err 0.022291\n",
      "Epoch 520: err 0.019819\n",
      "Epoch 560: err 0.017802\n",
      "Epoch 600: err 0.016136\n",
      "Epoch 640: err 0.014745\n",
      "Epoch 680: err 0.013573\n",
      "Epoch 720: err 0.012576\n",
      "Epoch 760: err 0.011721\n",
      "Epoch 800: err 0.010982\n",
      "Epoch 840: err 0.010340\n",
      "Epoch 880: err 0.009777\n",
      "Epoch 920: err 0.009281\n",
      "Epoch 960: err 0.008841\n",
      "Epoch 1000: err 0.008448\n",
      "Epoch 1040: err 0.008096\n",
      "Epoch 1080: err 0.007779\n",
      "Epoch 1120: err 0.007491\n",
      "Epoch 1160: err 0.007228\n",
      "Epoch 1200: err 0.006988\n",
      "Epoch 1240: err 0.006767\n",
      "Epoch 1280: err 0.006562\n",
      "Epoch 1320: err 0.006372\n",
      "Epoch 1360: err 0.006195\n",
      "Epoch 1400: err 0.006030\n",
      "Epoch 1440: err 0.005874\n",
      "Epoch 1480: err 0.005728\n",
      "Epoch 1520: err 0.005590\n",
      "Epoch 1560: err 0.005459\n",
      "Epoch 1600: err 0.005335\n",
      "Epoch 1640: err 0.005216\n",
      "Epoch 1680: err 0.005104\n",
      "Epoch 1720: err 0.004996\n",
      "Epoch 1760: err 0.004893\n",
      "Epoch 1800: err 0.004794\n",
      "Epoch 1840: err 0.004699\n",
      "Epoch 1880: err 0.004607\n",
      "Epoch 1920: err 0.004519\n",
      "Epoch 1960: err 0.004434\n",
      "Epoch 1999: err 0.004354\n"
     ]
    }
   ],
   "source": [
    "epochs = 2000\n",
    "# Print 50 status updates along with last\n",
    "status_points = list(range(epochs))\n",
    "status_points = status_points[::epochs // 50] + [status_points[-1]]\n",
    "for i in range(epochs):    \n",
    "    pred = model.predictions(X)\n",
    "    err = model.train(X, y)\n",
    "    if i in status_points:\n",
    "        print(\"Epoch %i: err %f\" % (i, err))\n",
    " "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "plt.plot(pred, color=\"darkred\")\n",
    "plt.plot(y, color=\"steelblue\")\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
